Data Exploration, Training Models, and XGBoost SHAP Values
Description
On this page, we present the plots and graphs used for data exploration, the training of several machine learning models, and an interpretability analysis of the XGBoost model using SHAP values.
In [1]:
# Imports
import numpy as np
import pandas as pd
df = pd.read_csv("author_sentiment.csv")
df.head(10)
Out[1]:
| | TITLE | TARGET_ENTITY | DOCUMENT | TRUE_SENTIMENT | text_cleaned | text_processed | num_uppercase | num_first_pronoun | num_second_pronoun | num_third_pronoun | ... | white | win | woman | work | world | would | write | year | york | young |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | German bank LBBW wins EU bailout approval | Landesbank Baden Wuertemberg | Germany's Landesbank Baden Wuertemberg won EU ... | Negative | Germany's Landesbank Baden Wuertemberg won EU ... | Germany/NNP 's/POS Landesbank/NNP Baden/NNP Wu... | 2.0 | 0.0 | 0.0 | 8.0 | ... | 0.0 | 0.123543 | 0.000000 | 0.000000 | 0.000000 | 0.232421 | 0.000000 | 0.143785 | 0.0 | 0.0 |
| 1 | 8th LD Writethru: 9th passenger released from ... | Rolando Mendoza | The Philippine National Police (PNP) identifie... | Neutral | The Philippine National Police (PNP) identifie... | the/DT Philippine/NNP National/NNP Police/NNP ... | 5.0 | 0.0 | 0.0 | 5.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 |
| 2 | Commission: Bar Liberian president from office | Charles Taylor | Sirleaf 70 acknowledged before the commissio... | Negative | Sirleaf 70 acknowledged before the commission ... | sirleaf/NN 70/CD acknowledge/VBD before/IN the... | 0.0 | 1.0 | 0.0 | 16.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.092898 | 0.000000 | 0.148777 | 0.000000 | 0.000000 | 0.0 | 0.0 |
| 3 | AP Exclusive: Network flaw causes scary Web error | Sawyers | Sawyer logged off and asked her sister Mari ... | Neutral | Sawyer logged off and asked her sister Mari 31... | Sawyer/NNP log/VBD off/RP and/CC ask/VBD her/P... | 0.0 | 0.0 | 0.0 | 15.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.179045 | 0.000000 | 0.0 | 0.0 |
| 4 | Holyfield ' s wife says boxer hit her several ... | Candi Holyfield | Candi Holyfield said in the protective order t... | Neutral | Candi Holyfield said in the protective order t... | Candi/NNP Holyfield/NNP say/VBD in/IN the/DT p... | 0.0 | 5.0 | 1.0 | 17.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.113763 | 0.000000 | 0.105568 | 0.0 | 0.0 |
| 5 | Hillary Clinton : Misogyny is ` endemic ' . | Hillary Clinton | -LRB- CNN -RRB- Hillary Clinton slammed what s... | Neutral | -LRB- CNN -RRB- Hillary Clinton slammed what s... | -LRB-/NNP CNN/NNP -RRB-/NNP Hillary/NNP Clinto... | 4.0 | 0.0 | 0.0 | 5.0 | ... | 0.0 | 0.000000 | 0.250242 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 |
| 6 | Trouser-wearing women fined $200 in Sudan | Lubna Hussein | Lubna Hussein was among 13 women arrested July... | Neutral | Lubna Hussein was among 13 women arrested July... | Lubna/NNP Hussein/NNP be/VBD among/IN 13/CD wo... | 1.0 | 0.0 | 0.0 | 14.0 | ... | 0.0 | 0.000000 | 0.344940 | 0.000000 | 0.224427 | 0.139241 | 0.000000 | 0.000000 | 0.0 | 0.0 |
| 7 | Hillary Clinton Compares Donald Trump With Har... | Hillary Clinton | "A lot of people thought I was probably exagge... | Neutral | "A lot of people thought I was probably exagge... | "/`` A/DT lot/NN of/IN people/NNS think/VBD I/... | 0.0 | 2.0 | 0.0 | 3.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 |
| 8 | Feature: "Chinese is an important part of my l... | Maria Rukodelnikova | Rukodelnikova is fond of a lot things from Chi... | Positive | Rukodelnikova is fond of a lot things from Chi... | Rukodelnikova/NNP be/VBZ fond/JJ of/IN a/DT lo... | 0.0 | 0.0 | 0.0 | 6.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.183372 | 0.000000 | 0.000000 | 0.000000 | 0.136258 | 0.0 | 0.0 |
| 9 | Former Australian Opposition leader attacks ne... | Tony Abbott | Former Australian Opposition leader Malcolm Tu... | Neutral | Former Australian Opposition leader Malcolm Tu... | former/JJ Australian/NNP Opposition/NNP leader... | 2.0 | 2.0 | 0.0 | 8.0 | ... | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 |
10 rows × 223 columns
In [2]:
# Configuration
from traitlets.config import Config
import nbformat as nbf
from nbconvert.exporters import HTMLExporter
from nbconvert.preprocessors import TagRemovePreprocessor
# Setup config
c = Config()
# Configure tag removal: tag any cell you want stripped from the export
# with `remove_cell` (or change the tag names below to suit)
c.TagRemovePreprocessor.remove_cell_tags = ("remove_cell",)
c.TagRemovePreprocessor.remove_all_outputs_tags = ("remove_output",)
c.TagRemovePreprocessor.remove_input_tags = ("remove_input",)
c.TagRemovePreprocessor.enabled = True
# Configure our exporter; from_filename returns a tuple: the HTML source
# first, then the notebook metadata
c.HTMLExporter.preprocessors = ["nbconvert.preprocessors.TagRemovePreprocessor"]
exporter = HTMLExporter(config=c)
exporter.register_preprocessor(TagRemovePreprocessor(config=c), True)
output = exporter.from_filename("ml_models.ipynb")
# Write the HTML to the output file
with open("ml_models.html", "w") as f:
    f.write(output[0])
In [3]:
from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['TRUE_SENTIMENT'].values)  # Negative = 0, Neutral = 1, Positive = 2
# print(df.head(5))
# Address class imbalance. Note: value_counts() sorts by count in descending
# order, so these variables hold the largest, middle, and smallest class
# counts -- not the counts for labels 0, 1, and 2. Here the middle count
# happens to belong to the neutral class (label 1), which is used as the
# resampling target below.
class_count_0, class_count_1, class_count_2 = df['label'].value_counts()
class_0 = df[df['label'] == 0] # neg
class_1 = df[df['label'] == 1] # neu
class_2 = df[df['label'] == 2] # pos
print('class 0:', class_0.shape)
print('class 1:', class_1.shape)
print('class 2:', class_2.shape)
print(class_count_2)
class_0_over = class_0.sample(class_count_1, replace=True, ignore_index=True)  # oversample the minority (negative) class
class_2_under = class_2.sample(class_count_1, ignore_index=True)  # undersample the majority (positive) class
new_df = pd.concat([class_1, class_0_over, class_2_under], axis=0, ignore_index=True)
print("total class of 0, 1 and 2:", new_df['label'].value_counts())  # counts after resampling
new_df['label'].value_counts().plot(kind='bar', title='Number of each class')
class 0: (351, 224)
class 1: (1246, 224)
class 2: (1758, 224)
351
total class of 0, 1 and 2: label
1    1246
0    1246
2    1246
Name: count, dtype: int64
Out[3]:
<Axes: title={'center': 'Number of each class'}, xlabel='label'>
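As a side note, positional unpacking of `value_counts()` can silently mismatch labels, since the counts come back sorted by frequency rather than by label. A minimal sketch (on a hypothetical toy frame, not the notebook's data) of the same over/under-sampling with label-based indexing:

```python
import pandas as pd

# Toy stand-in for `df`: label 0 is the minority class, label 2 the majority.
toy = pd.DataFrame({"label": [0] * 3 + [1] * 6 + [2] * 9})

# Index the counts by label instead of unpacking by position.
counts = toy["label"].value_counts()
target = counts.loc[1]  # size of the middle (neutral) class

balanced = pd.concat(
    [
        toy[toy["label"] == 1],
        toy[toy["label"] == 0].sample(target, replace=True, random_state=0),  # oversample
        toy[toy["label"] == 2].sample(target, random_state=0),                # undersample
    ],
    ignore_index=True,
)
print(balanced["label"].value_counts().to_dict())
```

Label-based indexing keeps working even if the frequency order of the classes changes between runs or datasets.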
Correlation Heat Map for Features 1-17 and Target Label
In [4]:
# Correlation matrix
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
# Calculate the correlation matrix
correlation_matrix = new_df.iloc[:, list(range(6, 23)) + [-1]].corr()
# Plot the heatmap
plt.figure(figsize=(15, 15))
sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', fmt=".2f", annot_kws={"size": 10})
plt.title('Correlation Heatmap')
plt.show()
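As a numeric companion to the heatmap, the most correlated feature pairs can be extracted programmatically by keeping the strict upper triangle of the matrix and stacking it. A sketch on hypothetical stand-in data (`demo` and its column names are placeholders):

```python
import numpy as np
import pandas as pd

# Stand-in for a numeric feature frame; "d" is built to track "a" closely.
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(100, 3)), columns=["a", "b", "c"])
demo["d"] = demo["a"] * 0.9 + rng.normal(scale=0.1, size=100)

corr = demo.corr()
# Keep the strict upper triangle so each pair appears once, diagonal excluded.
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
pairs = upper.stack().abs().sort_values(ascending=False)
print(pairs.head(3))  # strongest absolute correlations first
```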
Scatter Plot of Average Length of Token by Each Sentiment Class
In [5]:
# Visualization I
# !pip install plotly
import plotly.express as px
# plotting the bubble chart
fig = px.scatter(new_df, x="TRUE_SENTIMENT", y="avg_len_token",
                 size="avg_len_sen", color="TRUE_SENTIMENT")
fig.update_layout(title='Scatter plot of average length of token by each sentiment class',
                  xaxis_title='Sentiment Class',
                  yaxis_title='Average Length Token')
# showing the plot
fig.show()
Scatter Plot of Number of Proper Nouns by Each Sentiment Class
In [6]:
# Visualization II
!pip install plotly
import plotly.express as px
# plotting the bubble chart
fig = px.scatter(new_df, x="TRUE_SENTIMENT", y="num_proper_noun",
                 size="avg_len_sen", color="TRUE_SENTIMENT")
fig.update_layout(title='Scatter plot of number of proper nouns by each sentiment class',
                  xaxis_title='Sentiment Class',
                  yaxis_title='Number of Proper Nouns')
# showing the plot
fig.show()
Requirement already satisfied: plotly in /opt/conda/lib/python3.11/site-packages (5.20.0)
Requirement already satisfied: tenacity>=6.2.0 in /opt/conda/lib/python3.11/site-packages (from plotly) (8.2.3)
Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from plotly) (23.2)
Scatter Plot of Number of Past-Tense Verbs by Each Sentiment Class
In [7]:
# Visualization III
!pip install plotly
import plotly.express as px
# plotting the bubble chart
fig = px.scatter(new_df, x="TRUE_SENTIMENT", y="num_past_verb",
                 size="avg_len_sen", color="TRUE_SENTIMENT")
fig.update_layout(title='Scatter plot of number of past-tense verbs by each sentiment class',
                  xaxis_title='Sentiment Class',
                  yaxis_title='Number of Past-Tense Verbs')
# showing the plot
fig.show()
Requirement already satisfied: plotly in /opt/conda/lib/python3.11/site-packages (5.20.0)
Requirement already satisfied: tenacity>=6.2.0 in /opt/conda/lib/python3.11/site-packages (from plotly) (8.2.3)
Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from plotly) (23.2)
In [8]:
# Separate features and label
feat = new_df.iloc[:, 6:-2]  # engineered + TF-IDF feature columns (the slice stops before the final 'young' column)
feat_names = list(feat.columns)
label = new_df.loc[:, ['label']]
print(new_df)
print(feat.shape)
print(label.shape)
print(feat_names)
TITLE TARGET_ENTITY \
0 8th LD Writethru: 9th passenger released from ... Rolando Mendoza
1 AP Exclusive: Network flaw causes scary Web error Sawyers
2 Holyfield ' s wife says boxer hit her several ... Candi Holyfield
3 Hillary Clinton : Misogyny is ` endemic ' . Hillary Clinton
4 Trouser-wearing women fined $200 in Sudan Lubna Hussein
... ... ...
3733 Thousands travel to Yangon for Pope's diplomat... Pope Francis
3734 Er the Oscars' In Memoriam section showed a wo... Jan Chapman
3735 Big cat: Jack needs new home and a diet Jack Jack
3736 Woman accuses George H.W. Bush of groping her ... H. W. Bush
3737 Interview: Language cornerstone of Russia-Chin... Elizabeth Pavlova
DOCUMENT TRUE_SENTIMENT \
0 The Philippine National Police (PNP) identifie... Neutral
1 Sawyer logged off and asked her sister Mari ... Neutral
2 Candi Holyfield said in the protective order t... Neutral
3 -LRB- CNN -RRB- Hillary Clinton slammed what s... Neutral
4 Lubna Hussein was among 13 women arrested July... Neutral
... ... ...
3733 YANGON (Reuters) - Thousands of Catholics gath... Positive
3734 Advertisement - Continue Reading Below\nAustra... Positive
3735 Skip in Skip x Embed x Share CLOSE Jack the ... Positive
3736 Another woman has come forward to accuse form... Positive
3737 "It's important for Russia and China to cultiv... Positive
text_cleaned \
0 The Philippine National Police (PNP) identifie...
1 Sawyer logged off and asked her sister Mari 31...
2 Candi Holyfield said in the protective order t...
3 -LRB- CNN -RRB- Hillary Clinton slammed what s...
4 Lubna Hussein was among 13 women arrested July...
... ...
3733 YANGON (Reuters) - Thousands of Catholics gath...
3734 Advertisement - Continue Reading Below Austral...
3735 Skip in Skip x Embed x Share CLOSE Jack the 30...
3736 Another woman has come forward to accuse forme...
3737 "It's important for Russia and China to cultiv...
text_processed num_uppercase \
0 the/DT Philippine/NNP National/NNP Police/NNP ... 5.0
1 Sawyer/NNP log/VBD off/RP and/CC ask/VBD her/P... 0.0
2 Candi/NNP Holyfield/NNP say/VBD in/IN the/DT p... 0.0
3 -LRB-/NNP CNN/NNP -RRB-/NNP Hillary/NNP Clinto... 4.0
4 Lubna/NNP Hussein/NNP be/VBD among/IN 13/CD wo... 1.0
... ... ...
3733 YANGON/NNP (/-LRB- Reuters/NNP )/-RRB- -/: tho... 8.0
3734 advertisement/NN -/: continue/VB read/VBG belo... 0.0
3735 skip/VB in/IN Skip/NNP x/SYM Embed/NNP x/NN \n... 2.0
3736 another/DT woman/NN have/VBZ come/VBN forward/... 3.0
3737 "/`` it/PRP be/VBZ important/JJ for/IN Russia/... 0.0
num_first_pronoun num_second_pronoun num_third_pronoun ... win \
0 0.0 0.0 5.0 ... 0.0
1 0.0 0.0 15.0 ... 0.0
2 5.0 1.0 17.0 ... 0.0
3 0.0 0.0 5.0 ... 0.0
4 0.0 0.0 14.0 ... 0.0
... ... ... ... ... ...
3733 0.0 0.0 18.0 ... 0.0
3734 5.0 0.0 8.0 ... 0.0
3735 0.0 0.0 6.0 ... 0.0
3736 2.0 0.0 27.0 ... 0.0
3737 0.0 0.0 8.0 ... 0.0
woman work world would write year york \
0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0
1 0.000000 0.000000 0.000000 0.000000 0.179045 0.000000 0.0
2 0.000000 0.000000 0.000000 0.113763 0.000000 0.105568 0.0
3 0.250242 0.000000 0.000000 0.000000 0.000000 0.000000 0.0
4 0.344940 0.000000 0.224427 0.139241 0.000000 0.000000 0.0
... ... ... ... ... ... ... ...
3733 0.133844 0.000000 0.000000 0.081042 0.000000 0.000000 0.0
3734 0.000000 0.312165 0.000000 0.000000 0.000000 0.000000 0.0
3735 0.000000 0.000000 0.000000 0.107136 0.000000 0.198836 0.0
3736 0.388910 0.000000 0.000000 0.000000 0.000000 0.000000 0.0
3737 0.000000 0.000000 0.000000 0.166829 0.087332 0.361226 0.0
young label
0 0.000000 1
1 0.000000 1
2 0.000000 1
3 0.000000 1
4 0.000000 1
... ... ...
3733 0.000000 2
3734 0.000000 2
3735 0.000000 2
3736 0.000000 2
3737 0.101125 2
[3738 rows x 224 columns]
(3738, 216)
(3738, 1)
['num_uppercase', 'num_first_pronoun', 'num_second_pronoun', 'num_third_pronoun', 'num_coord_conj', 'num_past_verb', 'num_future_verb', 'num_comma', 'num_multi_punc', 'num_common_noun', 'num_proper_noun', 'num_adverb', 'num_wh', 'num_slang', 'avg_len_sen', 'avg_len_token', 'num_sen', '000', '10', '2016', '2017', 'accord', 'add', 'allegation', 'also', 'american', 'another', 'around', 'ask', 'back', 'become', 'begin', 'believe', 'big', 'bill', 'bush', 'business', 'call', 'campaign', 'case', 'change', 'charge', 'child', 'city', 'claim', 'clinton', 'close', 'come', 'company', 'continue', 'could', 'country', 'court', 'cruz', 'day', 'deal', 'donald', 'early', 'election', 'end', 'even', 'face', 'family', 'far', 'feel', 'find', 'fire', 'first', 'follow', 'force', 'former', 'four', 'game', 'get', 'give', 'go', 'good', 'government', 'great', 'group', 'head', 'help', 'high', 'hold', 'home', 'house', 'include', 'interview', 'issue', 'itâ', 'james', 'job', 'keep', 'know', 'last', 'late', 'later', 'law', 'lead', 'leader', 'leave', 'life', 'like', 'live', 'long', 'look', 'lose', 'lot', 'lsb', 'make', 'man', 'many', 'may', 'medium', 'meet', 'meeting', 'member', 'million', 'monday', 'month', 'move', 'much', 'name', 'national', 'need', 'never', 'new', 'news', 'next', 'night', 'north', 'office', 'official', 'old', 'one', 'open', 'part', 'party', 'pay', 'people', 'percent', 'photo', 'place', 'plan', 'play', 'point', 'police', 'political', 'post', 'president', 'public', 'put', 'question', 'really', 'release', 'report', 'republican', 'return', 'reuters', 'right', 'romney', 'rsb', 'run', 'sanders', 'say', 'school', 'season', 'second', 'see', 'senate', 'set', 'sexual', 'share', 'show', 'since', 'speak', 'start', 'state', 'statement', 'still', 'story', 'support', 'take', 'talk', 'team', 'tell', 'thing', 'think', 'though', 'three', 'thursday', 'time', 'top', 'trump', 'try', 'tuesday', 'turn', 'two', 'united', 'use', 'vote', 'want', 'washington', 'way', 'wednesday', 'week', 'well', 'white', 
'win', 'woman', 'work', 'world', 'would', 'write', 'year', 'york']
In [9]:
# Train-Test split
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    feat, label, test_size=0.2, random_state=1100)
print(X_train)
print(X_test)
print(y_train)
print(y_test)
num_uppercase num_first_pronoun num_second_pronoun num_third_pronoun \
3212 1.0 5.0 8.0 14.0
325 0.0 0.0 0.0 7.0
1255 0.0 0.0 0.0 1.0
454 0.0 0.0 0.0 7.0
645 0.0 2.0 2.0 10.0
... ... ... ... ...
968 2.0 0.0 0.0 9.0
2288 0.0 5.0 2.0 27.0
2991 4.0 4.0 0.0 14.0
2093 2.0 0.0 1.0 5.0
3462 1.0 1.0 0.0 2.0
num_coord_conj num_past_verb num_future_verb num_comma \
3212 30.0 14.0 2.0 0.0
325 3.0 6.0 0.0 0.0
1255 3.0 1.0 0.0 3.0
454 3.0 6.0 0.0 0.0
645 4.0 6.0 0.0 0.0
... ... ... ... ...
968 4.0 10.0 0.0 0.0
2288 14.0 16.0 3.0 0.0
2991 14.0 24.0 2.0 0.0
2093 3.0 16.0 0.0 0.0
3462 4.0 2.0 0.0 0.0
num_multi_punc num_common_noun ... well white win \
3212 61.0 183.0 ... 0.071516 0.0 0.000000
325 11.0 38.0 ... 0.000000 0.0 0.000000
1255 6.0 8.0 ... 0.000000 0.0 0.271660
454 11.0 38.0 ... 0.000000 0.0 0.000000
645 12.0 24.0 ... 0.142114 0.0 0.000000
... ... ... ... ... ... ...
968 15.0 43.0 ... 0.000000 0.0 0.000000
2288 52.0 85.0 ... 0.086881 0.0 0.106739
2991 49.0 91.0 ... 0.067443 0.0 0.082859
2093 21.0 57.0 ... 0.000000 0.0 0.000000
3462 8.0 21.0 ... 0.107220 0.0 0.000000
woman work world would write year york
3212 0.000000 0.00000 0.000000 0.055099 0.000000 0.000000 0.093510
325 0.000000 0.00000 0.000000 0.000000 0.000000 0.315297 0.000000
1255 0.000000 0.00000 0.000000 0.000000 0.000000 0.000000 0.000000
454 0.000000 0.00000 0.000000 0.000000 0.000000 0.315297 0.000000
645 0.000000 0.00000 0.000000 0.000000 0.000000 0.000000 0.000000
... ... ... ... ... ... ... ...
968 0.000000 0.00000 0.000000 0.415975 0.000000 0.000000 0.000000
2288 0.110547 0.00000 0.000000 0.066936 0.000000 0.000000 0.000000
2991 0.000000 0.06489 0.000000 0.000000 0.081601 0.000000 0.088184
2093 0.000000 0.00000 0.126974 0.000000 0.000000 0.000000 0.000000
3462 0.000000 0.00000 0.000000 0.165213 0.000000 0.000000 0.000000
[2990 rows x 216 columns]
num_uppercase num_first_pronoun num_second_pronoun num_third_pronoun \
3655 5.0 6.0 1.0 32.0
743 1.0 0.0 0.0 14.0
445 0.0 5.0 0.0 10.0
558 5.0 4.0 0.0 13.0
1930 6.0 6.0 3.0 41.0
... ... ... ... ...
1263 0.0 1.0 1.0 30.0
1749 2.0 0.0 0.0 6.0
2754 9.0 4.0 2.0 17.0
2080 1.0 0.0 0.0 1.0
3255 0.0 6.0 1.0 11.0
num_coord_conj num_past_verb num_future_verb num_comma \
3655 19.0 13.0 3.0 0.0
743 5.0 16.0 0.0 0.0
445 10.0 14.0 1.0 0.0
558 8.0 25.0 0.0 0.0
1930 16.0 40.0 2.0 0.0
... ... ... ... ...
1263 7.0 35.0 0.0 0.0
1749 7.0 17.0 0.0 0.0
2754 8.0 17.0 1.0 0.0
2080 1.0 7.0 0.0 6.0
3255 10.0 20.0 1.0 0.0
num_multi_punc num_common_noun ... well white win \
3655 52.0 84.0 ... 0.000000 0.06664 0.000000
743 17.0 53.0 ... 0.000000 0.00000 0.000000
445 31.0 71.0 ... 0.000000 0.00000 0.000000
558 38.0 100.0 ... 0.000000 0.00000 0.000000
1930 107.0 135.0 ... 0.056953 0.00000 0.000000
... ... ... ... ... ... ...
1263 57.0 73.0 ... 0.000000 0.00000 0.000000
1749 15.0 46.0 ... 0.000000 0.00000 0.000000
2754 34.0 48.0 ... 0.000000 0.00000 0.000000
2080 19.0 13.0 ... 0.000000 0.00000 0.000000
3255 60.0 40.0 ... 0.173591 0.00000 0.319905
woman work world would write year york
3655 0.000000 0.000000 0.0 0.040776 0.000000 0.000000 0.000000
743 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.000000
445 0.000000 0.000000 0.0 0.130476 0.000000 0.000000 0.000000
558 0.000000 0.097771 0.0 0.156582 0.000000 0.000000 0.000000
1930 0.289869 0.054797 0.0 0.043879 0.000000 0.162872 0.000000
... ... ... ... ... ... ... ...
1263 0.108881 0.000000 0.0 0.000000 0.000000 0.244712 0.223773
1749 0.506746 0.000000 0.0 0.076709 0.000000 0.071183 0.000000
2754 0.000000 0.083497 0.0 0.000000 0.000000 0.000000 0.000000
2080 0.000000 0.000000 0.0 0.000000 0.110757 0.000000 0.000000
3255 0.000000 0.083510 0.0 0.133742 0.000000 0.000000 0.000000
[748 rows x 216 columns]
label
3212 2
325 1
1255 0
454 1
645 1
... ...
968 1
2288 0
2991 2
2093 0
3462 2
[2990 rows x 1 columns]
label
3655 2
743 1
445 1
558 1
1930 0
... ...
1263 0
1749 0
2754 2
2080 0
3255 2
[748 rows x 1 columns]
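One optional refinement, sketched on hypothetical stand-in data: since the classes were only balanced by resampling, passing `stratify` to `train_test_split` keeps each class's share identical in the train and test sets (the split above does not use it):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Stand-in for (feat, label): 100 rows with an imbalanced label column.
X_demo = pd.DataFrame({"x": np.arange(100)})
y_demo = pd.Series([0] * 20 + [1] * 50 + [2] * 30, name="label")

# stratify=y_demo preserves the 20/50/30 class proportions in both splits.
Xtr, Xte, ytr, yte = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=1100, stratify=y_demo)
print(yte.value_counts().to_dict())
```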
In [10]:
# Train XGBoost with default hyperparameters
!pip install xgboost
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
import xgboost as xgb
# Init classifier
xgb_cl = xgb.XGBClassifier()
# Fit
xgb_cl.fit(X_train, y_train)
# Predict
preds = xgb_cl.predict(X_test)
accuracy = accuracy_score(y_test, preds)
f1 = f1_score(y_test, preds, average='macro')
# Score
print("F1 Score:", f1)
print("Accuracy:", accuracy)
Requirement already satisfied: xgboost in /opt/conda/lib/python3.11/site-packages (2.0.3)
Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from xgboost) (1.24.4)
Requirement already satisfied: scipy in /opt/conda/lib/python3.11/site-packages (from xgboost) (1.11.4)
F1 Score: 0.6538556985903352
Accuracy: 0.6564171122994652
In [11]:
# Grid search for XGBoost
from sklearn.model_selection import GridSearchCV
# Define the hyperparameter grid
param_grid = {
    'max_depth': [17, 20, 22],      # previously tried: 6, 7, 8, 10, 12, 15
    'learning_rate': [0.02, 0.03],  # previously tried: 0.008, 0.009, 0.01
    'subsample': [0.5]              # previously tried: 0.6
}
# Create the XGBoost model object
xgb_model = xgb.XGBClassifier()
# Create the GridSearchCV object
grid_search = GridSearchCV(xgb_model, param_grid, cv=5, scoring='accuracy')
# Fit the GridSearchCV object to the training data
grid_search.fit(X_train, y_train)
# Print the best set of hyperparameters and the corresponding score
print("Best set of hyperparameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_) # 0.6749163879598663
Best set of hyperparameters: {'learning_rate': 0.03, 'max_depth': 20, 'subsample': 0.5}
Best score: 0.6678929765886288
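Note that `GridSearchCV` refits the winning configuration on the full training data by default (`refit=True`), so the tuned model is available directly as `best_estimator_` with no manual retraining. A minimal sketch with a stand-in estimator on synthetic data:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data and a small grid over one hyperparameter.
X_demo, y_demo = make_classification(n_samples=200, random_state=0)
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    {"max_depth": [2, 4]},
    cv=3,
    scoring="accuracy",
)
search.fit(X_demo, y_demo)

# best_estimator_ is already refit on all of X_demo, y_demo.
best_model = search.best_estimator_
print(search.best_params_)
```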
In [12]:
# Train logistic regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
# Initialize the logistic regression model
model = LogisticRegression()
# Train the model
model.fit(X_train, y_train)
# Predict on the testing set
y_pred = model.predict(X_test)
# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average='macro')
print("F1 Score:", f1)
print("Accuracy:", accuracy)
/opt/conda/lib/python3.11/site-packages/sklearn/utils/validation.py:1183: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
F1 Score: 0.431597278233198
Accuracy: 0.43716577540106955
/opt/conda/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning:
lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
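A sketch of one way to address both warnings above, shown on stand-in data: scale the features so lbfgs converges, raise `max_iter` as a fallback, and pass the labels as a 1-D array (`y.values.ravel()` when `y` is a one-column DataFrame):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for (X_train, y_train); y_demo is already 1-D.
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=0)

# Scaling the inputs is usually enough for lbfgs to converge; max_iter is a fallback.
clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
clf.fit(X_demo, y_demo)
print(round(clf.score(X_demo, y_demo), 3))
```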
In [13]:
# Train Decision Tree
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import f1_score
# Initialize the decision tree classifier
dt_model = DecisionTreeClassifier()
# Train the model
dt_model.fit(X_train, y_train)
# Predict on the testing set
y_pred = dt_model.predict(X_test)
# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average='macro')
print("F1 Score:", f1)
print("Accuracy:", accuracy)
F1 Score: 0.6109020781073954
Accuracy: 0.6216577540106952
In [14]:
# Save the best-performing model. These hyperparameters came from further
# manual tuning and differ slightly from the grid-search optimum above.
from sklearn.metrics import f1_score
model_file_path = 'xgboost_model.bin'
xgb_model = xgb.XGBClassifier(max_depth=22, learning_rate=0.02, subsample=0.5)
xgb_model.fit(X_train, y_train)
preds = xgb_model.predict(X_test)
f1 = f1_score(y_test, preds, average='macro')
print("F1 Score:", f1)
print('Accuracy:', accuracy_score(y_test, preds))
# Save the trained model
xgb_model.save_model(model_file_path)
F1 Score: 0.6805725726632295
Accuracy: 0.6831550802139037
/opt/conda/lib/python3.11/site-packages/xgboost/core.py:160: UserWarning: [20:43:26] WARNING: /workspace/src/c_api/c_api.cc:1240: Saving into deprecated binary model format, please consider using `json` or `ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.
Sentiment Analysis of ChatGPT's Responses
In [15]:
# Predict sentiment of ChatGPT's responses
gpt3_df = pd.read_csv("response_gpt3.csv")
gpt3_df = gpt3_df.iloc[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 18], :-1]
gpt4_df = pd.read_csv("response_gpt4.csv")
gpt4_df = gpt4_df.iloc[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 18], :-1]
preds_gpt3 = xgb_model.predict(gpt3_df)  # make predictions on GPT-3's responses
preds_gpt4 = xgb_model.predict(gpt4_df)  # make predictions on GPT-4's responses
gpt3_df['predicted_sentiment'] = label_encoder.inverse_transform(preds_gpt3)
gpt4_df['predicted_sentiment'] = label_encoder.inverse_transform(preds_gpt4)
print(gpt3_df)
print(gpt4_df)
num_uppercase num_first_pronoun num_second_pronoun num_third_pronoun \
1 0.0 2.0 0.0 0.0
2 1.0 0.0 0.0 12.0
3 1.0 0.0 0.0 10.0
4 0.0 0.0 0.0 6.0
5 0.0 1.0 0.0 11.0
6 0.0 0.0 0.0 20.0
7 2.0 0.0 0.0 12.0
8 0.0 0.0 0.0 3.0
9 4.0 0.0 0.0 5.0
10 0.0 4.0 0.0 3.0
17 2.0 0.0 0.0 4.0
18 0.0 0.0 0.0 7.0
num_coord_conj num_past_verb num_future_verb num_comma num_multi_punc \
1 10.0 2.0 0.0 18.0 37.0
2 7.0 2.0 3.0 12.0 33.0
3 13.0 0.0 1.0 9.0 24.0
4 10.0 2.0 0.0 6.0 24.0
5 8.0 6.0 0.0 14.0 25.0
6 13.0 7.0 0.0 10.0 23.0
7 12.0 10.0 1.0 16.0 29.0
8 9.0 2.0 0.0 8.0 22.0
9 12.0 5.0 0.0 9.0 22.0
10 10.0 0.0 0.0 14.0 30.0
17 6.0 0.0 0.0 9.0 29.0
18 10.0 5.0 0.0 7.0 26.0
num_common_noun ... white win woman work world would \
1 86.0 ... 0.0 0.0 0.0 0.104915 0.000000 0.000000
2 57.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
3 71.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
4 66.0 ... 0.0 0.0 0.0 0.191973 0.000000 0.000000
5 56.0 ... 0.0 0.0 0.0 0.129537 0.000000 0.000000
6 65.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
7 62.0 ... 0.0 0.0 0.0 0.000000 0.165213 0.000000
8 64.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.099061
9 88.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
10 71.0 ... 0.0 0.0 0.0 0.000000 0.403363 0.750776
17 68.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
18 66.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
write year york predicted_sentiment
1 0.0 0.000000 0.0 Negative
2 0.0 0.000000 0.0 Neutral
3 0.0 0.000000 0.0 Negative
4 0.0 0.000000 0.0 Negative
5 0.0 0.000000 0.0 Negative
6 0.0 0.000000 0.0 Negative
7 0.0 0.000000 0.0 Neutral
8 0.0 0.091925 0.0 Negative
9 0.0 0.000000 0.0 Negative
10 0.0 0.000000 0.0 Negative
17 0.0 0.000000 0.0 Negative
18 0.0 0.000000 0.0 Negative
[12 rows x 217 columns]
num_uppercase num_first_pronoun num_second_pronoun num_third_pronoun \
1 0.0 0.0 1.0 4.0
2 2.0 1.0 0.0 9.0
3 2.0 1.0 0.0 5.0
4 1.0 4.0 0.0 9.0
5 3.0 1.0 0.0 15.0
6 0.0 3.0 0.0 14.0
7 4.0 0.0 0.0 17.0
8 0.0 3.0 0.0 0.0
9 2.0 2.0 0.0 5.0
10 0.0 4.0 0.0 8.0
17 3.0 0.0 0.0 4.0
18 1.0 0.0 0.0 10.0
num_coord_conj num_past_verb num_future_verb num_comma num_multi_punc \
1 9.0 5.0 0.0 19.0 40.0
2 2.0 3.0 2.0 12.0 28.0
3 8.0 2.0 1.0 13.0 29.0
4 11.0 6.0 0.0 9.0 27.0
5 6.0 15.0 0.0 15.0 36.0
6 10.0 18.0 0.0 20.0 48.0
7 10.0 15.0 1.0 15.0 40.0
8 5.0 10.0 0.0 15.0 32.0
9 13.0 10.0 1.0 11.0 25.0
10 12.0 4.0 0.0 21.0 43.0
17 6.0 2.0 2.0 10.0 27.0
18 9.0 6.0 1.0 15.0 36.0
num_common_noun ... white win woman work world would \
1 78.0 ... 0.0 0.0 0.0 0.105211 0.135790 0.000000
2 54.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
3 71.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
4 62.0 ... 0.0 0.0 0.0 0.175081 0.000000 0.000000
5 58.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
6 69.0 ... 0.0 0.0 0.0 0.117525 0.000000 0.000000
7 48.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
8 60.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
9 65.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
10 73.0 ... 0.0 0.0 0.0 0.473087 0.203529 0.252551
17 55.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
18 56.0 ... 0.0 0.0 0.0 0.000000 0.000000 0.000000
write year york predicted_sentiment
1 0.000000 0.000000 0.142979 Positive
2 0.000000 0.000000 0.000000 Negative
3 0.000000 0.000000 0.000000 Neutral
4 0.000000 0.000000 0.000000 Negative
5 0.000000 0.000000 0.000000 Negative
6 0.147791 0.174658 0.000000 Positive
7 0.000000 0.000000 0.000000 Positive
8 0.000000 0.071425 0.000000 Negative
9 0.000000 0.000000 0.000000 Neutral
10 0.000000 0.000000 0.000000 Positive
17 0.000000 0.000000 0.000000 Negative
18 0.000000 0.058539 0.000000 Neutral
[12 rows x 217 columns]
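Tallying the `predicted_sentiment` columns printed above (counts transcribed from those outputs) gives a compact side-by-side summary of how the two models' responses were classified:

```python
import pandas as pd

# Predicted labels transcribed from the two prediction tables above.
gpt3_preds = pd.Series(["Negative"] * 10 + ["Neutral"] * 2)
gpt4_preds = pd.Series(["Positive"] * 4 + ["Negative"] * 5 + ["Neutral"] * 3)

# Rename each count series so the concatenated columns are distinguishable.
summary = (
    pd.concat(
        [gpt3_preds.value_counts().rename("gpt3"),
         gpt4_preds.value_counts().rename("gpt4")],
        axis=1,
    )
    .fillna(0)
    .astype(int)
)
print(summary)
```

The XGBoost model labels most GPT-3 responses Negative, while the GPT-4 responses spread across all three classes.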
SHAP Analysis of XGBoost
In [17]:
# Conduct SHAP analysis
# !pip install shap
# !pip install xgboost
import shap
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
import xgboost as xgb
# Load the JS visualization code into the notebook
shap.initjs()
# model_file_path = 'xgboost_model.bin'
# loaded_model = xgb.Booster()
# loaded_model.load_model(model_file_path)
# preds = loaded_model.predict(X_test)
# Compute SHAP values to inspect feature importance
explainer = shap.TreeExplainer(xgb_model)
shap_values = explainer.shap_values(X_test, check_additivity=False)
print(shap_values.shape)
# print(explainer.expected_value)
shap.force_plot(explainer.expected_value[0], shap_values[:, :, 0])
(748, 216, 3)
Out[17]:
Visualization omitted: the interactive SHAP force plot requires the JavaScript library, which is not loaded in this static export.
In [18]:
shap.summary_plot(shap_values[:, :, 0], features=feat, plot_type="bar") # class: negative
In [19]:
shap.summary_plot(shap_values[:, :, 1], features=feat, plot_type="bar") # class: neutral
In [20]:
shap.summary_plot(shap_values[:, :, 2], features=feat, plot_type="bar") # class: positive
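The bar summary plots above rank features by mean absolute SHAP value; the same ranking can be computed numerically from the `(n_samples, n_features, n_classes)` array. A sketch on hypothetical stand-in values (`demo_shap` and the feature names are placeholders):

```python
import numpy as np
import pandas as pd

# Stand-in SHAP array shaped like shap_values: (n_samples, n_features, n_classes).
rng = np.random.default_rng(0)
demo_shap = rng.normal(size=(50, 4, 3))
demo_features = ["f0", "f1", "f2", "f3"]

# Mean |SHAP| per feature for one class (here class 0, i.e. negative),
# sorted so the most influential feature comes first.
negative_class = 0
importance = pd.Series(
    np.abs(demo_shap[:, :, negative_class]).mean(axis=0), index=demo_features
).sort_values(ascending=False)
print(importance)
```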
In [21]:
shap.force_plot(explainer.expected_value[1], shap_values[0, :, 1], features=feat.iloc[0, :]) # class: neutral
Out[21]:
Visualization omitted: the interactive SHAP force plot requires the JavaScript library, which is not loaded in this static export.